import os
import random

images_path = '../../Mask_Face/Data/3classes/images/'   # 数据集图片目录
txt_name_path = 'data/3classes/'  # 存放数据集文件名的目录

train_name_txt = txt_name_path + 'train.txt'
val_name_txt = txt_name_path + 'val.txt'

train_proportion = 0.95 # 训练集所占数据集的比例

images_list = os.listdir(images_path) # 数据集所有图像名做成的列表

data_length = len(images_list) # 所有数据集大小

random.shuffle(images_list) # 将数据集所有图像名做成的列表打乱

train_num = int(data_length * train_proportion) # 训练集数

f_train = open(train_name_txt, 'w+')
f_val = open(val_name_txt, 'w+')

for index in range(data_length):
    if index < train_num:
        train_context = images_path + images_list[index] + '\n'
        f_train.write(train_context)
    else:
        val_context = images_path + images_list[index] + '\n'
        f_val.write(val_context)

f_train.close()
f_val.close()

